from __future__ import annotations

import numpy.random
import torch
import torch.backends.cudnn as cudnn
import torch.cuda as cuda
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader

def setup_seed(seed):
    """Seed every random number generator used in training for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and torch's CPU and
    CUDA generators. It also configures cuDNN to use deterministic kernels and
    to skip autotuning.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # function-scope: the module does not import ``random`` at top level

    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator and every CUDA device
    # Explicit for clarity; torch documents this as silently ignored without CUDA.
    cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    # Benchmark mode may pick different (possibly nondeterministic) algorithms per run.
    cudnn.benchmark = False

def print_dict(d: dict, title: str):
    """Print a banner containing *title*, then one ``key: value`` line per item of *d*.

    Prints nothing at all (not even the banner) when *d* has no items.

    Args:
        d: Mapping whose items are printed in iteration order.
        title: Text shown inside the banner line.
    """
    if len(d) == 0:
        return
    banner = f'==================== {title} ===================='
    print(banner)
    for name, value in d.items():
        print(f'{name}: {value}')

def static_vars(**kwargs):
    """Build a decorator that attaches each keyword argument to a function as an attribute.

    Emulates C-style static variables: after ``@static_vars(count=0)`` the
    decorated function exposes ``func.count``. Because it is an attribute on
    the function object, its value persists between calls. The function
    itself is returned unchanged.
    """
    def decorate(func):
        for attr_name, initial in kwargs.items():
            setattr(func, attr_name, initial)
        return func

    return decorate

@static_vars(best_te_acc=0)
@torch.no_grad()
def eval_perf(model: nn.Module, te_dataloader: DataLoader, loss_fun, *,
              device: str | torch.device = 'cuda'):
    """Evaluate *model* on a test loader and track the best accuracy seen so far.

    The model is moved to *device* and switched to eval mode for the pass. It
    is moved back to the CPU afterwards, even if evaluation raises, and is
    left in eval mode. Gradient tracking is disabled by ``torch.no_grad``.

    Args:
        model: Classifier; predictions are ``argmax`` of its outputs over dim 1.
        te_dataloader: Yields ``(inputs, targets)`` batches. Its ``dataset``
            must support ``len``.
        loss_fun: Called as ``loss_fun(outputs, targets)``. It must return a
            single-element tensor, because ``.item()`` is taken.
        device: Where evaluation runs. Defaults to ``'cuda'``, matching the
            previous hard-coded behavior; pass ``'cpu'`` to run without a GPU.

    Returns:
        dict with three keys:
        - ``te_loss``: summed per-batch losses divided by the sample count.
        - ``te_acc``: percentage of correct predictions.
        - ``best_te_acc``: the maximum ``te_acc`` returned by any call so far
          (stored on the function object).

    Raises:
        ZeroDivisionError: If the dataset is empty.
    """
    total_loss = 0.0
    total_correct = 0

    model.to(device)
    model.eval()
    try:
        for inputs, targets in te_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs: torch.Tensor = model(inputs)
            # NOTE(review): te_loss below divides this sum by the number of
            # samples. That is a per-sample mean only if loss_fun uses
            # reduction='sum'. With the default 'mean' reduction the result is
            # roughly 1/batch_size too small. TODO confirm what callers pass.
            total_loss += loss_fun(outputs, targets).item()
            predictions = torch.argmax(outputs, dim=1)
            total_correct += (predictions == targets).sum().item()
    finally:
        # Hand the model back on the CPU even when evaluation fails part-way.
        model.cpu()

    n_data = len(te_dataloader.dataset)  # type: ignore
    te_loss = total_loss / n_data
    te_acc = total_correct / n_data * 100.0

    # Running maximum kept on the function object (initialised by @static_vars).
    eval_perf.best_te_acc = max(eval_perf.best_te_acc, te_acc)
    return {
        'te_loss': te_loss,
        'te_acc': te_acc,
        'best_te_acc': eval_perf.best_te_acc,
    }

def bisection(min_val: float, max_val: float, tol: float, f):
    """Locate a sign change of *f* in ``[min_val, max_val]`` by bisection.

    The interval is halved until its width is at most *tol*, or until the
    midpoint can no longer move because the bounds are adjacent floats. The
    second stop condition is what makes ``tol <= 0`` terminate.

    The lower bound always keeps the sign that ``f(min_val)`` had at the start.
    The returned value is the final lower bound, so it lies at or below the
    sign change. If *f* does not change sign on the interval (this is not
    checked), the result ends up near *max_val*.

    Args:
        min_val: Lower end of the search interval.
        max_val: Upper end; must be strictly greater than *min_val*.
        tol: Stop once ``max_val - min_val <= tol``.
        f: Scalar function whose results support comparison with 0.

    Returns:
        The final lower bound of the bracketing interval.

    Raises:
        ValueError: If ``min_val < max_val`` does not hold.
    """
    if not min_val < max_val:
        raise ValueError(
            f'bisection requires min_val < max_val, got {min_val!r} and {max_val!r}')

    f_min = f(min_val)  # evaluated once; the lower bound's sign never changes
    while max_val - min_val > tol:
        mid_val = (min_val + max_val) / 2
        if not min_val < mid_val < max_val:
            # Bounds are adjacent floats (or not finite): no further progress possible.
            break
        f_mid = f(mid_val)
        # Compare signs directly. A product f_mid * f_min can underflow to 0
        # (or overflow) and pick the wrong half. Zero and NaN count as
        # "not the same sign", exactly as a non-positive product did.
        if (f_mid > 0 and f_min > 0) or (f_mid < 0 and f_min < 0):
            min_val, f_min = mid_val, f_mid
        else:
            max_val = mid_val
    return min_val